{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n",
        "# CLIP Classify Content of Video\n",
        "\n",
        "---\n",
        "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow/inference/blob/main/docs/notebooks/clip_classification.ipynb)\n",
        "\n",
        "CLIP is a powerful foundation model for zero-shot classification. In this scenario, we are using CLIP to classify the topics in a Youtube video. Plug in your own video and set of prompts!\n",
        "\n",
        "Click the Open in Colab button to run the cookbook on Google Colab.\n",
        "\n",
        "**Let's begin!**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Install required packages\n",
        "In this cookbook, we'll leverage two Python packages - `opencv` and `supervision`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "!pip install supervision opencv-python"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Imports & Configure Roboflow Inference Server"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "id": "PNOSIM4xT9s8"
      },
      "outputs": [],
      "source": [
        "import requests\n",
        "import base64\n",
        "from PIL import Image\n",
        "from io import BytesIO\n",
        "import os\n",
        "import supervision as sv\n",
        "from tqdm import tqdm\n",
        "from supervision import get_video_frames_generator\n",
        "import time\n",
        "\n",
        "INFERENCE_ENDPOINT = \"https://infer.roboflow.com\"\n",
        "API_KEY = \"YOUR_API_KEY\"\n",
        "VIDEO = \"VIDEO_PATH\"\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Prompt List for CLIP similarity function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "#Prompt list to evaluate similarity between each image and each prompt. If something else is selected, then we ignore the caption\n",
        "#change this to your desired prompt list\n",
        "prompt_list = [['action video game shooting xbox','Drake rapper music','soccer game ball',\n",
        "                'marvel combic book','beyonce','Church pope praying',\n",
        "                'Mcdonalds French Fries',\"something else\"]]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## CLIP Endpoint Compare Frame & Prompt List Similarity"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "id": "YLvdj1eMd2WX"
      },
      "outputs": [],
      "source": [
        "def classify_image(image: str, prompt: str) -> dict:\n",
        "    \n",
        "    image_data = Image.fromarray(image)\n",
        "\n",
        "    buffer = BytesIO()\n",
        "    image_data.save(buffer, format=\"JPEG\")\n",
        "    image_data = base64.b64encode(buffer.getvalue()).decode(\"utf-8\")\n",
        "\n",
        "    payload = {\n",
        "        \"api_key\": API_KEY,\n",
        "        \"subject\": {\n",
        "            \"type\": \"base64\",\n",
        "            \"value\": image_data\n",
        "        },\n",
        "        \"prompt\": prompt,\n",
        "    }\n",
        "\n",
        "    data = requests.post(INFERENCE_ENDPOINT + \"/clip/compare?api_key=\" + API_KEY, json=payload)\n",
        "\n",
        "    response = data.json()\n",
        "    #print(response[\"similarity\"])\n",
        "    sim = response[\"similarity\"]\n",
        "\n",
        "    highest_prediction = 0\n",
        "    highest_prediction_index = 0\n",
        "\n",
        "    for i, prediction in enumerate(response[\"similarity\"]):\n",
        "        if prediction > highest_prediction:\n",
        "            highest_prediction = prediction\n",
        "            highest_prediction_index = i\n",
        "\n",
        "    return prompt[highest_prediction_index], sim[highest_prediction_index]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Process Video & Return Most Similar Prompt to Frame"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "\n",
        "def process_video_frames(video_path, prompt_list, total_frames=160, total_seconds=80, stride_length=30,max_retries):\n",
        "    if not os.path.exists(video_path):\n",
        "        print(f\"The specified video file '{video_path}' does not exist.\")\n",
        "        return\n",
        "\n",
        "    frames_per_second = total_frames / total_seconds\n",
        "    frame_dict = {}\n",
        "\n",
        "    for frame_index, frame in enumerate(sv.get_video_frames_generator(source_path=video_path, stride=stride_length, start=0)):\n",
        "        frame_second = frame_index * (1 / frames_per_second)\n",
        "        frame_key = f\"Frame {frame_index}: {frame_second:.2f} seconds\"\n",
        "        frame_dict[frame_key] = []\n",
        "\n",
        "        print(frame_key)\n",
        "        retries = 0\n",
        "\n",
        "        for prompt in prompt_list:\n",
        "            try: \n",
        "                label, similarity = classify_image(frame)\n",
        "                if label != \"something else\":\n",
        "                    print('label found')\n",
        "                    frame_dict[frame_key].append({label: similarity})\n",
        "                    print('\\n')\n",
        "\n",
        "            except Exception as e:\n",
        "                retries += 1\n",
        "                print(f\"Error: {e}\")\n",
        "                print(f\"Retrying... (Attempt {retries}/{max_retries})\")\n",
        "\n",
        "                if retries >= max_retries:\n",
        "                    print(\"Max retries exceeded. Skipping frame.\")\n",
        "                    break\n",
        "\n",
        "    return frame_dict\n",
        "\n",
        "# Example usage:\n",
        "max_retries = 4\n",
        "prompt_list = prompt_list\n",
        "clip_results = process_video_frames(VIDEO, prompt_list,max_retries)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Create JSON file and filter out low similarity classes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Flatten the nested dictionary\n",
        "data = clip_results\n",
        "# Define the threshold based on the similarity score returned for the most similar prompt\n",
        "threshold = 0.22\n",
        "\n",
        "# Filter out key-value pairs below the threshold for each frame\n",
        "filtered_data = [\n",
        "    {\n",
        "        frame: [\n",
        "            {key: value}\n",
        "            for item in items\n",
        "            for key, value in item.items()\n",
        "            if value > threshold\n",
        "        ]\n",
        "    }\n",
        "    for frame, items in data.items()\n",
        "]\n",
        "print(filtered_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Specify the filename for the JSON file\n",
        "import json\n",
        "filename = f\"{str(threshold)}.json\"\n",
        "\n",
        "# Write the dictionary to the JSON file\n",
        "with open(filename, 'w') as json_file:\n",
        "    json.dump(filtered_data, json_file, indent=4)  # The indent parameter is optional for pretty-printing\n",
        "\n",
        "#print(f'Data has been written to {filename})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "![Youtube_Video](https://storage.googleapis.com/com-roboflow-marketing/drake_youtube.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "![ClipJson](https://storage.googleapis.com/com-roboflow-marketing/content_classification.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.17"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
